{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6sMznqOhL3G_"
      },
      "source": [
        "# ECG arrhythmia identification through time using the LSTM model\n",
        "The following notebooks labels ECG signal as either normal, left bundle branch blockage, right bundle branch blockage, premature ventrical contraction, or atrial premature beat using machine learning and the tensorflow library. It used the open souce MIT-BIH Arrhythmia Database containing a .csv file for raw ECG signal and a .txt file for annotations.\n",
        "\n",
        "Database: https://www.kaggle.com/datasets/mondejar/mitbih-database/data\n",
        "\n",
        "Database information: https://www.physionet.org/content/mitdb/1.0.0/\n",
        "\n",
        "Previous work has used the CNN model identification but in this code we will use the LSTM model as it better handles time squence data as an ECG raw signal would be."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-hpCuOICM_l9"
      },
      "source": [
        "## Install for the libraries not in colab\n",
        "This is a step needed if you don't have these libraries or if running the code for the first time in colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B2W_reX3tXvC",
        "outputId": "aa9f19f2-a795-4c3b-8527-6738b4f487b7"
      },
      "outputs": [],
      "source": [
        "%pip install opendatasets\n",
        "%pip install numpy\n",
        "%pip install tensorflow\n",
        "%pip install scipy\n",
        "%pip install PyWavelets\n",
        "%pip install matplotlib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hghP4KfrNmIV"
      },
      "source": [
        "##Import necessary libraries for data processing and the machine learning model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "jMZh93_TrWvu"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "2024-04-14 18:04:30.511462: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2\n",
            "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
          ]
        }
      ],
      "source": [
        "#Imports\n",
        "import csv\n",
        "import os\n",
        "import opendatasets as od\n",
        "from matplotlib import pyplot as plt\n",
        "import random\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import pywt                                                                     # for wave transform and denoising\n",
        "from scipy import stats                                                         # for normalizing the signal"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8c9lGIabPqQM"
      },
      "source": [
        "## Load the data set from kaggle and assign to a variable\n",
        "\n",
        "Must input your own kaggle username and password."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NQBBWVsSrW6L",
        "outputId": "ee3f492e-f209-4b47-ac10-279b6ef00af0"
      },
      "outputs": [],
      "source": [
        "# Load dataset (must put in kaggle username and password)\n",
        "url = \"https://www.kaggle.com/datasets/mondejar/mitbih-database\"\n",
        "od.download(url)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EVK95B4FtfGw"
      },
      "outputs": [],
      "source": [
        "#Grabbing the dataset after being downloaded\n",
        "dataset_url = \"/content/mitbih-database/mitbih_database\"\n",
        "dataset = os.listdir(dataset_url)\n",
        "dataset.sort()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y-gLTgoPQRfW"
      },
      "source": [
        "## Read the csv file and process the data\n",
        "The data is separated why two different arrays one for the dataframe (Full ECG signal) and the annotaions (data slice ranges, labels for range)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZaxC1TGsLsUp"
      },
      "source": [
        "The following code cell is a helper function to denoise the ECG signal chunk using wavelets. This function is taken from SAI JITHESH on kaggle who processed and used the data in a CNN model. https://www.kaggle.com/code/abhibasavapattana/eyegaze-classification-using-cnn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JB7fvRppotFe"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "Denoise the signal data using the wavelet sym4 label, threshold of 0.4,\n",
        "and found coefficients.\n",
        "\n",
        "input: list of signal data\n",
        "return: list of signal data\n",
        "'''\n",
        "def denoise(data):\n",
        "    wavelet_funtion = 'sym3'                                                      #found to be the best function for ECG \n",
        "\n",
        "    w = pywt.Wavelet(wavelet_funtion)\n",
        "    maxlev = pywt.dwt_max_level(len(data), w.dec_len)\n",
        "    threshold = 0.1                                                               # Threshold for filtering the higher the closer to the wavelet (less noise)\n",
        "\n",
        "    coeffs = pywt.wavedec(data, wavelet_funtion, level=maxlev)\n",
        "    for i in range(1, len(coeffs)):\n",
        "        coeffs[i] = pywt.threshold(coeffs[i], threshold*max(coeffs[i]))\n",
        "\n",
        "    datarec = pywt.waverec(coeffs, wavelet_funtion)\n",
        "    return datarec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dAvjkw6wvqR1"
      },
      "outputs": [],
      "source": [
        "dataframe = []\n",
        "annotations = []\n",
        "for file in dataset:\n",
        "  signal = []\n",
        "  if \".csv\" in file:\n",
        "    with open(dataset_url + '/' + file, newline= '') as csvfile:\n",
        "      read_csv = csv.reader(csvfile, delimiter=',')\n",
        "      next(read_csv, None)  # skip the headers\n",
        "      for row in read_csv:\n",
        "        signal.append(int(row[2]))\n",
        "    signal = denoise(signal)\n",
        "    signal = stats.zscore(signal)\n",
        "    dataframe.append(signal)\n",
        "\n",
        "  if \".txt\" in file:\n",
        "    annotation = []\n",
        "    with open(dataset_url + '/' + file, \"r\") as txtfile:\n",
        "      rows = txtfile.readlines()\n",
        "      for row in range(1,len(rows)): # txtfile:\n",
        "        row = rows[row].split()\n",
        "        annotation.append([row[1],row[2]])  # (index, label)\n",
        "        # print([row[1],row[2]])\n",
        "    annotations.append(annotation)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DLKvyBgZ72y0"
      },
      "source": [
        "Build lists to hold the chunks of ECG data and labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LDvv7uP4eyGp",
        "outputId": "07f62ecb-c43c-4f5b-dd33-d3194ce7f5e9"
      },
      "outputs": [],
      "source": [
        "#empty list to hold the chunks of ECG data and labels\n",
        "datachunk = [] #[[chunk1] ,[chunk2]]\n",
        "labels = [] #[[label1], [label2]]\n",
        "\n",
        "#initialize max length\n",
        "maxlength=0\n",
        "maxlenlist = []\n",
        "\n",
        "#loop through the list of annotations\n",
        "for i in range(0,len(annotations)):\n",
        "  # label = []\n",
        "  print(i)\n",
        "  for j in range(0, len(annotations[i])):\n",
        "    #try the following code to build the list (NOTE: we try because there is an error at the last value)\n",
        "    try:\n",
        "      #check if the label is in the label list (the data set has a few outlying labels outside the ones needed)\n",
        "      if(annotations[i][j][1] in ['N', 'L', 'R', 'V', 'A']):\n",
        "        #get the signal chunk based on the start and end points in annotations\n",
        "        start = int(annotations[i][j][0])\n",
        "        end = int(annotations[i][j+1][0])\n",
        "        datachunk.append(dataframe[i][start : end])                                # append data chunk to the datachunk list\n",
        "\n",
        "        #get the max length of data chunk for paddign\n",
        "        newlength = end - start\n",
        "        maxlenlist.append(newlength)\n",
        "\n",
        "        if newlength > maxlength:\n",
        "          maxlength = newlength\n",
        "          # print(newlength)\n",
        "          # maxlenlist.append(newlength)\n",
        "\n",
        "        #map the labels to numbers and append to the list (where N = 0, L = 1, R = 2, V = 3, A = 4)\n",
        "        if (annotations[i][j][1] == 'N'):\n",
        "          labels.append(0)\n",
        "        elif (annotations[i][j][1] == 'L'):\n",
        "          labels.append(1)\n",
        "        elif (annotations[i][j][1] == 'R'):\n",
        "          labels.append(2)\n",
        "        elif (annotations[i][j][1] == 'V'):\n",
        "          labels.append(3)\n",
        "        elif (annotations[i][j][1] == 'A'):\n",
        "          labels.append(4)\n",
        "\n",
        "    except:\n",
        "      #if the code fails save the following error message (NOTE: we are choosing to ignor the final data point so we don't care if it fails)\n",
        "      error_message = \"invalid end value %d\" %end                                  #this can be changed to be printed if the user wants\n",
        "    # labels.append(label)\n",
        "\n",
        "print(maxlength)\n",
        "print(max(maxlenlist))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7mjLGlVp6zcK"
      },
      "source": [
        "## Split the list of chunked data into test, train, and validation sets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K3QiBJDV11za"
      },
      "outputs": [],
      "source": [
        "#generate random number list the length of dataset\n",
        "shuffle_list = random.sample(range(0,len(datachunk)), len(datachunk))\n",
        "\n",
        "#make the sizes for the different data set (sections)\n",
        "train_size = 4*len(datachunk)//5\n",
        "test_size = int(0.1 * len(datachunk) + train_size)\n",
        "val_size  = int(0.1 * len(datachunk) + test_size)\n",
        "\n",
        "#get the test, train, validation data sets from random section of dataset\n",
        "train_set_datachunk = [datachunk[i] for i in shuffle_list[0:train_size]] #with labels\n",
        "test_set_datachunk = [datachunk[i] for i in shuffle_list[train_size:test_size]]\n",
        "val_set_datachunk = [datachunk[i] for i in shuffle_list[test_size:val_size]]\n",
        "\n",
        "train_set_labels = [labels[i] for i in shuffle_list[0:train_size]] #with labels\n",
        "test_set_labels = [labels[i] for i in shuffle_list[train_size:test_size]]\n",
        "val_set_labels = [labels[i] for i in shuffle_list[test_size:val_size]]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xKWF41GszjXm"
      },
      "outputs": [],
      "source": [
        "#pad the datachunks to make the signals all the same length (500) to be updates\n",
        "train_set_datachunk = tf.keras.preprocessing.sequence.pad_sequences(train_set_datachunk, maxlength)\n",
        "test_set_datachunk = tf.keras.preprocessing.sequence.pad_sequences(test_set_datachunk, maxlength)\n",
        "val_set_datachunk = tf.keras.preprocessing.sequence.pad_sequences(val_set_datachunk, maxlength)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DVh-eZHQ-rW6"
      },
      "outputs": [],
      "source": [
        "#convert to numpy array (make this cleaner)\n",
        "train_set_datachunk = np.array(train_set_datachunk) #[chunks1, chunk2, ...]\n",
        "test_set_datachunk  = np.array(test_set_datachunk)\n",
        "val_set_datachunk = np.array(val_set_datachunk)\n",
        "\n",
        "train_set_labels = np.array(train_set_labels)\n",
        "test_set_labels  = np.array(test_set_labels)\n",
        "val_set_labels = np.array(val_set_labels)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#save the files (as needed)\n",
        "np.savetxt(\"train_set_datachunk.csv\",train_set_datachunk , delimiter=\",\")\n",
        "np.savetxt(\"test_set_datachunk.csv\", test_set_datachunk, delimiter=\",\")\n",
        "np.savetxt(\"val_set_datachunk.csv\", val_set_datachunk, delimiter=\",\")\n",
        "\n",
        "np.savetxt(\"train_set_labels.csv\",train_set_labels , delimiter=\",\")\n",
        "np.savetxt(\"test_set_labels.csv\", test_set_labels, delimiter=\",\")\n",
        "np.savetxt(\"val_set_labels.csv\", val_set_labels, delimiter=\",\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "o4tKB-jbCsC7",
        "outputId": "80f95c91-c2a4-4d3d-e595-a69ec674830e"
      },
      "outputs": [],
      "source": [
        "#for testing purposes and understanding the data chunk shape and the label shape\n",
        "print(train_set_datachunk.shape)\n",
        "print(train_set_labels.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Nr2mfC-tFp1j",
        "outputId": "a661809a-f24b-4f45-914f-2cb9b330d09a"
      },
      "outputs": [],
      "source": [
        "#this is for testing\n",
        "print(len(train_set_datachunk))\n",
        "print(len(test_set_datachunk))\n",
        "print(len(val_set_datachunk))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Wj7XkSlu3Mc2",
        "outputId": "b156a752-cc2a-46bd-f28e-4eb63959b65e"
      },
      "outputs": [],
      "source": [
        "#for testing purposes and understanding the data chunk shape\n",
        "\n",
        "#get the max length\n",
        "maxlength = 0\n",
        "index = 0\n",
        "for i in range(len(datachunk)):\n",
        "  newlength = len(datachunk[i])\n",
        "  if newlength > maxlength:\n",
        "    maxlength = newlength\n",
        "    index = i\n",
        "\n",
        "print(maxlength, index, labels[index])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Y27OWz5XgjR"
      },
      "source": [
        "## Plot each data chuncks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "aVuS-0Y9E-UD",
        "outputId": "6eac721e-0b38-4a8a-e75a-95c1ac319d4c"
      },
      "outputs": [],
      "source": [
        "#loop through the list of annotations\n",
        "for i in range(1,len(annotations)):\n",
        "  #set the title to the lable\n",
        "  plt.title(labels[i])\n",
        "  #plot the data for the data sections\n",
        "  plt.plot(datachunk[i]) #dataframe[int(annotations[i][0]):int(annotations[i+1][0])])\n",
        "  #show the plot\n",
        "  plt.show()\n",
        "  #close the plot so we can show the next section\n",
        "  plt.close()\n",
        "\n",
        "# plt.plot(dataframe[188561:188904])\n",
        "# plt.plot(dataframe[188904:189199])\n",
        "# plt.plot(dataframe[189199:189423])\n",
        "# ['188561', 'N']\n",
        "# ['188561', 'N']\n",
        "# ['188904', 'V']\n",
        "# ['189199', 'N']\n",
        "# ['189423', 'N']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QGOD1HUV7THY"
      },
      "source": [
        "## Build the LSTM model\n",
        "Using the specific model variables of batch_size, epochs, units, input_dim,\n",
        "sample_size, and time_step.  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Cpv7flWrZQR0",
        "outputId": "70727a5e-07fe-4b73-8dbc-04dd61b3b54f"
      },
      "outputs": [
        {
          "ename": "NameError",
          "evalue": "name 'train_set_datachunk' is not defined",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
            "Input \u001b[0;32mIn [2]\u001b[0m, in \u001b[0;36m<cell line: 6>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      4\u001b[0m units \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m5\u001b[39m                                                                        \u001b[38;5;66;03m# number of labels\u001b[39;00m\n\u001b[1;32m      5\u001b[0m input_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m                                                                    \u001b[38;5;66;03m# number of features\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m sample_size \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_set_datachunk\u001b[49m\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]                                       \u001b[38;5;66;03m# number of total ECG samples\u001b[39;00m\n\u001b[1;32m      7\u001b[0m time_step \u001b[38;5;241m=\u001b[39m train_set_datachunk\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]                                         \u001b[38;5;66;03m# length of the ECG chunk\u001b[39;00m\n\u001b[1;32m      8\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m (batch_size, time_step, input_dim)\n",
            "\u001b[0;31mNameError\u001b[0m: name 'train_set_datachunk' is not defined"
          ]
        }
      ],
      "source": [
        "#initialize all model specific variables\n",
        "batch_size = 151                                                                 # make this a number divisible by the total number of samples\n",
        "epochs = 10\n",
        "units = 5                                                                        # number of labels\n",
        "input_dim = 1                                                                    # number of features\n",
        "sample_size = train_set_datachunk.shape[0]                                       # number of total ECG samples\n",
        "time_step = train_set_datachunk.shape[1]                                         # length of the ECG chunk\n",
        "input_shape = (batch_size, time_step, input_dim)\n",
        "\n",
        "#define the model\n",
        "model = tf.keras.Sequential([\n",
        "    #add the layers of the model\n",
        "    #tf.keras.layers.Lambda(lambda x: tf.reshape(x,input_shape)),                 # reshape layer to make the input into LSTM layer 3D -> expected input data shape for LSTM: (batch_size, timesteps, data_dim)\n",
        "    # tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(batch_size, return_sequences=True, dropout = 0.2)),      # returns a sequence of vectors of dimension batch_size\n",
        "    tf.keras.layers.LSTM(batch_size, input_shape = (time_step, input_dim), activation='relu', return_sequences=True, dropout = 0.2),      # returns a sequence of vectors of dimension batch_size\n",
        "    tf.keras.layers.LSTM(batch_size, input_shape = (time_step, input_dim), activation='relu', return_sequences=True, dropout = 0.2),      # returns a sequence of vectors of dimension batch_size\n",
        "    tf.keras.layers.LSTM(batch_size, input_shape = (time_step, input_dim), activation='relu', return_sequences=True, dropout = 0.2),      # returns a sequence of vectors of dimension batch_size\n",
        "    tf.keras.layers.LSTM(batch_size, input_shape = (time_step, input_dim), activation='relu', return_sequences=True, dropout = 0.2),      # returns a sequence of vectors of dimension batch_size\n",
        "    tf.keras.layers.LSTM(batch_size, input_shape = (time_step, input_dim), dropout = 0.2),                             # returns 1xbatch_size\n",
        "    tf.keras.layers.Dense(units, activation = \"softmax\")                        # softmax for multiclass labeling\n",
        "\n",
        "])\n",
        "\n",
        "#build the model passing int the input shape\n",
        "# model.build((input_shape))\n",
        "\n",
        "#print the model summary\n",
        "print(model.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MxjcKjwFYpFq"
      },
      "source": [
        "Compile the model specifying the loss and optimizer functions and the elvuation metrics for accuracy. Using the sparse categorical crossentropy method for the loss function to achieve multiple numerical matchings of the labels to the ECG singal chunks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XJe5jCQyqJIC"
      },
      "outputs": [],
      "source": [
        "#compile the model\n",
        "model.compile(loss = 'sparse_categorical_crossentropy', optimizer='adam', metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) #loss for numerical and multiple matching"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iBG966FPYRm0"
      },
      "source": [
        "## Train the model by fitting the LSTM model to the train data sets\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u-QS-L3MNiz8"
      },
      "source": [
        "Save checkpoints during training. Using the tf.keras.callbacks.ModelCheckpoint we can save the model both during and at the end of training.\n",
        "This allows us to use trained model without having to retrain it, or pick-up training where you left off in case the training process was interrupted.\n",
        "\n",
        "NOTE: it is needed to load in the model to use the saved version."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xCVbbz0IMtSp"
      },
      "outputs": [],
      "source": [
        "#initialize variables for the file path and directory of where to save the check points\n",
        "checkpoint_path = \"training_1/cp.ckpt\"\n",
        "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
        "\n",
        "# Create a callback that saves the model's weights\n",
        "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n",
        "                                                 save_best_only=True,\n",
        "                                                 verbose=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zaFXKYltOoAl"
      },
      "source": [
        "Time to train the model with the training data set and assign the output to new variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 382
        },
        "id": "DH3Sexg3qiXf",
        "outputId": "819cda44-6e6b-451e-fcf0-2a8fb6038172"
      },
      "outputs": [],
      "source": [
        "#train the model (where iterations is total num smaples/ batchsize )\n",
        "train_model = model.fit(train_set_datachunk, train_set_labels, epochs = epochs, batch_size = batch_size, callbacks=[cp_callback], validation_data=(test_set_datachunk,test_set_labels))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o03C7Q1DX_Mc"
      },
      "source": [
        "## Test the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-h9ox6KCHBut",
        "outputId": "3bcf09e5-8730-4656-f71d-6b239ec4286e"
      },
      "outputs": [],
      "source": [
        "# Loads the weights\n",
        "model.load_weights(checkpoint_path)\n",
        "\n",
        "#test the model with a higher batch size\n",
        "loss, acc = model.evaluate(test_set_datachunk, test_set_labels, verbose = 0) # make this a higher batchsize but based on a common factor between test and train set size"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c5gyasvP_8Rs"
      },
      "outputs": [],
      "source": [
        "# Save the entire model as a `.keras` zip archive.\n",
        "model.save('my_model.keras')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Convert the model to the TensorFlow Lite format without quantization\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] \n",
        "converter._experimental_lower_tensor_list_ops = False\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "# Save the model to disk\n",
        "open(\"ECG_LSTM_model.tflite\", \"wb\").write(tflite_model)\n",
        "  \n",
        "basic_model_size = os.path.getsize(\"gesture_model.tflite\")\n",
        "print(\"Model is %d bytes\" % basic_model_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ebFf2R0L_7ND"
      },
      "outputs": [],
      "source": [
        "#load in the model (if needed for later steps)\n",
        "new_model = tf.keras.models.load_model('/Users/zoeboysen/Desktop/my_model.keras')       # set safe mode to false to make sure your computer trusts the load"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CpLXjQDfYiCG"
      },
      "source": [
        "## Validate the model and show the evaluation metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "Is--varmXt_J",
        "outputId": "20a0d66f-06b0-4e47-d618-a30cf692837c"
      },
      "outputs": [],
      "source": [
        "#plot the accuracy \n",
        "print(train_model.history.keys())\n",
        "plt.title(\"Testing and Training Accuracy\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.xlabel(\"Epoch\")\n",
        "plt.plot(train_model.history[\"sparse_categorical_accuracy\"])\n",
        "plt.plot(train_model.history['val_sparse_categorical_accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QM4qovYMPUUO",
        "outputId": "78c642ea-8e8f-4107-f102-cae72186aeb9"
      },
      "outputs": [],
      "source": [
        "#plot the loss\n",
        "print(train_model.history.keys())\n",
        "plt.title(\"Testing and Training Loss\")\n",
        "plt.ylabel(\"Loss\")\n",
        "plt.xlabel(\"Epoch\")\n",
        "plt.plot(train_model.history[\"loss\"])\n",
        "plt.plot(train_model.history['val_loss'])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
